package com.zee.admin.interceptor;

import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.zee.common.model.dto.DataScopeDTO;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.lang.reflect.Field;
import java.util.Map;

/**
 * MyBatis-Plus inner interceptor that applies row-level data filtering to queries.
 *
 * <p>Before a query runs, the mapper parameters are searched for a {@link DataScope}. It may be
 * passed directly, nested inside an {@code @Param}-annotated map argument, or held in the
 * {@code dataScope} field of a {@link DataScopeDTO} instance. When a scope with a non-blank SQL
 * filter is found, that filter is AND-ed into the statement's {@code WHERE} clause.
 *
 * @author Mark LinZee
 * @email LinZee666@163.com
 */
@Slf4j
public class DataFilterInterceptor implements InnerInterceptor {

    /** Name of the field that holds the {@link DataScope} inside {@link DataScopeDTO} types. */
    private static final String DATA_SCOPE_FIELD = "dataScope";

    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object params,
                            RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
        DataScope scope = getDataScope(params);
        // No data scope, or an empty filter: run the query unchanged
        if (scope == null || StrUtil.isBlank(scope.getSqlFilter())) {
            return;
        }

        // Append the filter to the WHERE clause
        String buildSql = getSelect(boundSql.getSql(), scope);

        // Rewrite the bound SQL in place
        PluginUtils.mpBoundSql(boundSql).sql(buildSql);
    }

    /**
     * Finds the data scope carried by the mapper parameters.
     *
     * <p>Supported shapes:
     * <ul>
     *   <li>{@code params} is itself a DataScope or a DataScopeDTO.</li>
     *   <li>{@code params} is a map (multi-argument / {@code @Param} mapper methods) whose values
     *       are a DataScope, a DataScopeDTO, or a nested map containing one of those.</li>
     * </ul>
     *
     * @param params the raw parameter object MyBatis passes to the interceptor; may be null
     * @return the first data scope found, or null when there is none
     */
    private DataScope getDataScope(Object params) {
        if (params == null) {
            return null;
        }
        if (!(params instanceof Map)) {
            return asDataScope(params);
        }

        // Inspect every entry. A value that is not itself a map must not abort the search,
        // otherwise whether the filter is applied would depend on the map's iteration order.
        for (Object value : ((Map<?, ?>) params).values()) {
            DataScope scope = asDataScope(value);
            if (scope == null && value instanceof Map) {
                // A map argument: the DAO layer is expected to annotate it with @Param
                scope = findInMap((Map<?, ?>) value);
            }
            if (scope != null) {
                return scope;
            }
        }
        return null;
    }

    /** Returns the first data scope found among the values of a nested parameter map, or null. */
    private static DataScope findInMap(Map<?, ?> map) {
        for (Object value : map.values()) {
            DataScope scope = asDataScope(value);
            if (scope != null) {
                return scope;
            }
        }
        return null;
    }

    /** Interprets a single parameter value as a data scope, returning null if it is not one. */
    private static DataScope asDataScope(Object candidate) {
        if (candidate instanceof DataScope) {
            return (DataScope) candidate;
        }
        if (candidate instanceof DataScopeDTO) {
            return readDataScopeField(candidate);
        }
        return null;
    }

    /**
     * Reads the {@code dataScope} field of a DataScopeDTO instance via reflection.
     *
     * <p>The class hierarchy is walked upwards because the field may be declared on an ancestor
     * rather than on the runtime class or its direct superclass.
     *
     * @return the field value if it is a DataScope; otherwise null (including when the field is
     *         missing or cannot be read)
     */
    private static DataScope readDataScopeField(Object target) {
        for (Class<?> clazz = target.getClass(); clazz != null; clazz = clazz.getSuperclass()) {
            Field field;
            try {
                field = clazz.getDeclaredField(DATA_SCOPE_FIELD);
            } catch (NoSuchFieldException e) {
                // Not declared at this level; keep looking in the superclass
                continue;
            }
            try {
                field.setAccessible(true);
                Object value = field.get(target);
                return value instanceof DataScope ? (DataScope) value : null;
            } catch (IllegalAccessException | RuntimeException e) {
                log.warn("Unable to read '{}' from {}", DATA_SCOPE_FIELD, clazz.getName(), e);
                return null;
            }
        }
        return null;
    }

    /**
     * Returns {@code buildSql} with the scope's SQL filter AND-ed into its WHERE clause.
     *
     * <p>Both the existing condition and the filter are parenthesized, so an {@code OR} on
     * either side cannot widen the result set past the filter. String literals in the
     * statement are preserved.
     *
     * @param buildSql the original SELECT statement
     * @param scope    data scope with a non-blank SQL filter
     * @return the rewritten SQL, or {@code buildSql} unchanged if it cannot be parsed
     * @throws IllegalStateException if the statement is not a plain SELECT, or the filter is
     *                               not a parseable SQL condition
     */
    private String getSelect(String buildSql, DataScope scope) {
        Select select;
        try {
            select = (Select) CCJSqlParserUtil.parse(buildSql);
        } catch (JSQLParserException e) {
            // NOTE(review): this fallback fails open (the query runs without the data filter);
            // kept for compatibility with the original behaviour, but now logged.
            log.warn("Data scope filter skipped, unable to parse SQL: {}", buildSql, e);
            return buildSql;
        }

        if (!(select.getSelectBody() instanceof PlainSelect)) {
            throw new IllegalStateException(
                    "Data scope filtering only supports plain SELECT statements: " + buildSql);
        }
        PlainSelect plainSelect = (PlainSelect) select.getSelectBody();

        String filter = "(" + scope.getSqlFilter() + ")";
        Expression where = plainSelect.getWhere();
        String condition = where == null ? filter : "(" + where + ") AND " + filter;
        try {
            plainSelect.setWhere(CCJSqlParserUtil.parseCondExpression(condition));
        } catch (JSQLParserException e) {
            // Fail closed: never run the query without a filter the caller asked for
            throw new IllegalStateException("Invalid data scope filter: " + scope.getSqlFilter(), e);
        }

        return select.toString();
    }
}